### Simulation + Estimation of multiple network replicates ###
## Inputs
#'  n_node: Integer; number of nodes in net
#'  n_sim: Int.; number of simulated networks to generate (and models to fit)
#'  seed: Int.; seed of RNG (set via set.seed() and also passed to mmsbm())
#'  B: Matrix; true network blockmodel
#'  mu_b: Vector; 2-element prior mean for in-group (1st) and out-of-group (2nd)
#'        log-odds of a connection.
#'  alpha: Double; alpha parameter for CG initialization strategy.
#'  lda_init: Boolean; should CG be used as initialization strategy for mmsbm()?
#'  beta_arr: Array; 2x2x2 (block x monadic pred x HMM state) array of
#'            monadic coefficients
#'  directed: Boolean; should simulated net. be directed?
#'  type: String; Type of scenario being simulated (stored in the 'Network'
#'        column of 'pred_data').
#'  hess: Boolean; Should Hessian matrix of estimated hyper-param be computed?
#'  f.dyad: Formula; Dyadic formula passed to mmsbm()
#'  f.monad: Formula; Monadic formula passed to mmsbm()
#'  tol: Double; passed to mmsbm() as the 'conv_tol' control option.
#'  nets_only: Boolean; Should the function return simulated networks only (without
#'             fitting models to them)?
#'  SynthNets: List; list of simulated networks (as returned by using nets_only=TRUE).
#'             When supplied, no new networks are simulated and its length
#'             must equal n_sim.
#'
## Output
#' List of simulated networks (if nets_only=TRUE) or estimated models,
#' with an added dataframe of predicted vs. true mixed-membership values
#' under 'pred_data'.
#'
## Details
#' Estimated block and HMM-state labels are arbitrary, so each fitted model is
#' re-labelled before it is returned: blocks are matched to the true blocks by
#' solving the assignment problem that maximises the overlap between the
#' estimated mixed memberships and the true 'pi_vecs', and HMM states are
#' matched in the same way against the indicator matrix of 'sVec'. Row/dim
#' names of the permuted objects are reset to their original values.
#'

GenSimData <- function(n_node = 100,
                       n_sim = 1,
                       seed = 831213,
                       B,
                       mu_b = c(.5, -1.0),
                       alpha = 0.5,
                       lda_init = TRUE,
                       beta_arr,
                       directed = TRUE,
                       type = "Realistic",
                       hess = FALSE,
                       f.dyad = Y ~ V1,
                       f.monad = ~ V2,
                       tol = 1e-3,
                       nets_only = FALSE,
                       SynthNets = NULL) {
  set.seed(seed)

  if (is.null(SynthNets)) {
    # Single place for the simulation settings; extra arguments (X, Z) are
    # forwarded to NetSim2() untouched.
    sim_one <- function(...) {
      NetSim2(BLK = 2,
              NODE = n_node,
              STATE = 2,
              TIME = 9,
              DIRECTED = directed,
              N_PRED = c(1, 1),
              B = B,
              beta_arr = beta_arr,
              sVec = rep(c(1, 2), c(5, 4)),
              gamma_vec = c(0.1),
              ...)
    }
    # The X and Z components of an initial draw are reused by every replicate.
    first_net <- sim_one()
    X_set <- first_net$X
    Z_set <- first_net$Z
    SynthNets <- replicate(n_sim,
                           sim_one(X = X_set, Z = Z_set),
                           simplify = FALSE)
  } else if (length(SynthNets) != n_sim) {
    stop("Number of elements in SynthNets != n_sim.")
  }

  if (nets_only) {
    return(SynthNets)
  }

  net3.model.sim <- vector("list", n_sim)
  # seq_len() so that n_sim = 0 yields an empty result instead of iterating
  # over c(1, 0).
  for (i in seq_len(n_sim)) {
    cat("\t Running model",i,"\n")
    net <- SynthNets[[i]]
    fit <- mmsbm(formula.dyad = f.dyad,
                 formula.monad = f.monad,
                 senderID = "node1",
                 receiverID = "node2",
                 nodeID = "node",
                 timeID = "year",
                 data.dyad = net$dyad.data,
                 data.monad = net$monad.data,
                 n.blocks = net$BLK,
                 n.hmmstates = net$STATE,
                 directed = net$DIRECTED,
                 mmsbm.control = list(mu_b = mu_b,
                                      var_b = c(1, 1),
                                      var_beta = 1,
                                      eta = 1.3,
                                      init_gibbs = lda_init,
                                      alpha = alpha,
                                      hessian = hess,
                                      seed = seed,
                                      verbose = FALSE,
                                      vi_iter = 5000,
                                      conv_tol = tol))

    ## Align estimated block labels with the true ones
    loss.mat.phi <- fit$MixedMembership %*% net$pi_vecs
    phi_ord <- clue::solve_LSAP(t(loss.mat.phi), maximum = TRUE)

    orig_mm_names <- rownames(fit$MixedMembership)
    fit$MixedMembership <- fit$MixedMembership[phi_ord, , drop = FALSE]
    rownames(fit$MixedMembership) <- orig_mm_names

    orig_g_names <- dimnames(fit$BlockModel)
    fit$BlockModel <- fit$BlockModel[phi_ord, phi_ord, drop = FALSE]
    dimnames(fit$BlockModel) <- orig_g_names

    ## Align estimated HMM state labels with the true ones
    state_ind <- model.matrix(~ -1 + as.factor(net$sVec))
    loss.mat.kappa <- fit$Kappa %*% state_ind
    kappa_ord <- clue::solve_LSAP(t(loss.mat.kappa), maximum = TRUE)

    orig_coef_names <- dimnames(fit$MonadCoef)
    fit$MonadCoef <- fit$MonadCoef[, phi_ord, kappa_ord, drop = FALSE]
    dimnames(fit$MonadCoef) <- orig_coef_names

    orig_k_names <- rownames(fit$Kappa)
    fit$Kappa <- fit$Kappa[kappa_ord, , drop = FALSE]
    rownames(fit$Kappa) <- orig_k_names

    orig_tk_names <- dimnames(fit$TransitionKernel)
    fit$TransitionKernel <- fit$TransitionKernel[kappa_ord, kappa_ord, drop = FALSE]
    dimnames(fit$TransitionKernel) <- orig_tk_names

    ## Predicted vs. true mixed memberships, stacked block by block
    n_nt <- net$NODE * net$TIME
    fit$pred_data <- data.frame(
      Network = rep(type, each = n_nt * net$BLK),
      Group = factor(rep(seq_len(net$BLK), each = n_nt)),
      Year = rep(paste("Year", seq_len(net$TIME)),
                 each = net$NODE, times = net$BLK),
      # NOTE: hard-coded to the simulated design (9 periods: 5 in state 1,
      # 4 in state 2); data.frame() recycles it across blocks.
      State = rep(c("State 1", "State 2"), c(5 * net$NODE, 4 * net$NODE)),
      V2 = rep(net$monad.data[, "V2"], net$BLK),
      Truth = c(net$pi_vecs),
      # MixedMembership was already re-ordered above; indexing it by phi_ord
      # a second time would undo (or scramble) the alignment.
      Pred = c(t(fit$MixedMembership))
    )

    net3.model.sim[[i]] <- fit
  }
  net3.model.sim
}